{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Folder that holds the trained Keras model file.\n",
    "input_fld = 'input_fld_path'\n",
    "# File name of the Keras model; it is joined onto input_fld when loading.\n",
    "weight_file = 'kerasmodel_file_name located inside input_fld'\n",
    "\n",
    "# Comma-separated names; the number of entries sets how many model outputs are renamed.\n",
    "output_node_names_of_input_network = 'pred0'\n",
    "# Comma-separated names given to the renamed output nodes and passed to freeze_graph.\n",
    "output_node_names_of_final_network = 'output_node'\n",
    "\n",
    "# When True, an ASCII copy of the graph definition is written next to the binary one.\n",
    "write_graph_def_ascii_flag = True\n",
    "# File name of the frozen graph, joined onto the output folder.\n",
    "output_graph_name = 'constant_graph_weights.pb'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import load_model\n",
    "import tensorflow as tf\n",
    "import os\n",
    "import os.path as osp\n",
    "\n",
    "# Build the output folder with osp.join so input_fld works with or without a\n",
    "# trailing separator. The final '' keeps a trailing separator on output_fld:\n",
    "# later cells use output_fld itself as the checkpoint prefix and extend it by\n",
    "# plain string concatenation.\n",
    "output_fld = osp.join(input_fld, 'tensorflow_model', '')\n",
    "# makedirs(exist_ok=True) also creates missing parent folders and is safe to\n",
    "# re-run, unlike a separate isdir() check followed by mkdir().\n",
    "os.makedirs(output_fld, exist_ok=True)\n",
    "weight_file_path = osp.join(input_fld, weight_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Keras model and rename outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output nodes names are:  ['output_node']\n"
     ]
    }
   ],
   "source": [
    "net_model = load_model(weight_file_path)\n",
    "\n",
    "num_output = len(output_node_names_of_input_network.split(','))\n",
    "pred_node_names = output_node_names_of_final_network.split(',')\n",
    "# Both settings are comma-separated lists that must line up one-to-one;\n",
    "# fail here with a clear message instead of an IndexError or a missing node later.\n",
    "if len(pred_node_names) != num_output:\n",
    "    raise ValueError('expected %d name(s) in output_node_names_of_final_network, got %d'\n",
    "                     % (num_output, len(pred_node_names)))\n",
    "\n",
    "# net_model.outputs is always a list of tensors. net_model.output is a bare\n",
    "# tensor for a single-output model, so indexing it would slice the tensor\n",
    "# itself instead of selecting an output.\n",
    "pred = [tf.identity(net_model.outputs[i], name=pred_node_names[i])\n",
    "        for i in range(num_output)]\n",
    "print('output nodes names are: ', pred_node_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['output_node']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Only the last expression of a cell is displayed, so print the renamed\n",
    "# tensors explicitly and leave the node names as the cell's result.\n",
    "print(pred)\n",
    "pred_node_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras.models import load_model\n",
    "from tensorflow.python.framework.graph_util import convert_variables_to_constants\n",
    "from tensorflow.python.saved_model import utils\n",
    "import tensorflow as tf\n",
    "from tensorflow.core.framework import graph_pb2\n",
    "from tensorflow.core.protobuf import saver_pb2\n",
    "from tensorflow.python.client import session\n",
    "# from tensorflow.python.framework import graph_io\n",
    "from tensorflow.python.framework import importer\n",
    "from tensorflow.python.framework import ops\n",
    "from tensorflow.python.framework import test_util\n",
    "from tensorflow.python.framework import graph_util\n",
    "from tensorflow.python.ops import math_ops\n",
    "# from tensorflow.python.ops import variable\n",
    "from tensorflow.python.platform import test\n",
    "from tensorflow.python.training import saver as saver_lib\n",
    "from tensorflow.contrib.session_bundle import exporter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# The checkpoint prefix is the output folder itself (no 'checkpoint' base\n",
    "# name is appended), so the saved checkpoint path is output_fld + '-<step>'.\n",
    "checkpoint_file = output_fld\n",
    "checkpoint_state_name = \"checkpoint_state\"\n",
    "\n",
    "# Every file path is built with osp.join: the result is unchanged while\n",
    "# output_fld ends with a separator and stays correct if it does not.\n",
    "input_graph_name = osp.join(output_fld, 'only_the_graph_def.pb')\n",
    "output_graph_path = osp.join(output_fld, output_graph_name)\n",
    "keras_architecture = osp.join(output_fld, 'keras_architecture_json')\n",
    "keras_weights = osp.join(output_fld, 'keras_weights')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import backend as K\n",
    "\n",
    "# Grab the session and graph that Keras populated when the model was loaded.\n",
    "sess = K.get_session()\n",
    "graph = tf.get_default_graph()\n",
    "input_graph_def = graph.as_graph_def()\n",
    "\n",
    "# Save the variables as a checkpoint; global_step=0 appends '-0' to the prefix.\n",
    "# Reuse checkpoint_state_name rather than repeating its literal value.\n",
    "saver = tf.train.Saver()\n",
    "saver.save(sess, checkpoint_file, global_step=0, latest_filename=checkpoint_state_name)\n",
    "\n",
    "# Write the graph definition: optional ASCII copy plus the binary file.\n",
    "if write_graph_def_ascii_flag:\n",
    "    tf.train.write_graph(input_graph_def, output_fld, 'only_the_graph_def.pb.ascii', as_text=True)\n",
    "tf.train.write_graph(input_graph_def, output_fld, 'only_the_graph_def.pb', as_text=False)\n",
    "\n",
    "# Dump the graph as a summary event file, then close the writer so the event\n",
    "# file is flushed to disk and its file handle is released.\n",
    "writer = tf.summary.FileWriter(output_fld, graph=graph)\n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from /home/amir/deep-batch/trained/KerasEchoQualityMultiStream/snapshots/freeze_uniformValid_bn/s/tensorflow_model/-0\n",
      "INFO:tensorflow:Froze 27 variables.\n",
      "Converted 27 variables to const ops.\n",
      "541 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "import freeze_graph\n",
    "import os\n",
    "\n",
    "# NOTE(review): this assignment runs after the graph definition was already\n",
    "# written to disk by an earlier cell, so it presumably does not change what\n",
    "# gets frozen. Confirm whether K.set_learning_phase(0) before load_model is\n",
    "# what was intended.\n",
    "K._LEARNING_PHASE = tf.constant(0)\n",
    "\n",
    "# Settings handed to freeze_graph; they are passed positionally below.\n",
    "input_saver_def_path = \"\"  # deprecated\n",
    "input_binary = True  # the graph definition file was written with as_text=False\n",
    "restore_op_name = \"save/restore_all\"\n",
    "filename_tensor_name = \"save/Const:0\"\n",
    "clear_devices = False\n",
    "\n",
    "# The '-0' suffix matches the checkpoint that was saved with global_step=0.\n",
    "freeze_args = (\n",
    "    input_graph_name,\n",
    "    input_saver_def_path,\n",
    "    input_binary,\n",
    "    checkpoint_file + '-0',\n",
    "    output_node_names_of_final_network,\n",
    "    restore_op_name,\n",
    "    filename_tensor_name,\n",
    "    output_graph_path,\n",
    "    clear_devices,\n",
    "    \"\",\n",
    "    None,\n",
    ")\n",
    "freeze_graph.freeze_graph(*freeze_args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
